import numpy as np

class GenData(object):
    """Random generator of real-valued linear-model samples ``y = H x + n``.

    Complex quantities are produced first and then stacked into their
    real-valued equivalents: vectors become ``[real, imag]`` and the channel
    becomes the block matrix ``[[Hr, -Hi], [Hi, Hr]]``.
    """

    def __init__(self, params):
        """Store the generator configuration.

        Args:
            params: Mapping providing the keys ``'BATCH_SIZE'``, ``'Nr'``,
                ``'Nt'``, ``'CONSTELLATION'`` and ``'CONSTELLATION_COMPLEX'``.
                ``Nr`` x ``Nt`` is the shape of each complex channel matrix
                (presumably receive x transmit antennas -- confirm with the
                caller). ``'CONSTELLATION_COMPLEX'`` is the 1-D set of
                complex symbols that transmitted entries are drawn from.
        """
        self.batch_size = params['BATCH_SIZE']
        self.Nr = params['Nr']
        self.Nt = params['Nt']
        # Stored for callers; not read by any method of this class.
        self.constellation = params['CONSTELLATION']
        self.constellation_complex = params['CONSTELLATION_COMPLEX']
        self.sqrt2 = np.sqrt(2)

    def rayleigh_varying(self, snr, noisefactor):
        """Draw one batch with an independent random channel per sample.

        H ~ CN(0, I): real and imaginary parts are i.i.d. N(0, 1/2).
        n ~ CN(0, sigma2 * I): each real noise component has variance
        sigma2 / 2.

        Args:
            snr: Value the per-sample channel power is divided by to obtain
                the noise variance (a plain divisor, i.e. linear scale).
            noisefactor: Truthy to add the noise to ``y_real``; falsy to
                return the noiseless observation.

        Returns:
            Tuple ``(x_real, H_real, y_real, sigma2)`` with shapes
            ``(batch, 2*Nt)``, ``(batch, 2*Nr, 2*Nt)``, ``(batch, 2*Nr)``
            and ``(batch, 1)``.
        """
        # Coerce to an ndarray so that a plain list/tuple constellation can
        # be indexed with the integer index array below.
        symbols = np.asarray(self.constellation_complex)

        s = np.random.randint(low=0, high=symbols.shape[0],
                              size=[self.batch_size, self.Nt])
        x = symbols[s]
        x_real = np.concatenate((np.real(x), np.imag(x)), 1)

        Hr = np.random.randn(self.batch_size, self.Nr, self.Nt) / self.sqrt2
        Hi = np.random.randn(self.batch_size, self.Nr, self.Nt) / self.sqrt2
        # Real-valued equivalent of the complex matrix Hr + 1j * Hi.
        H_real = np.concatenate(
            [np.concatenate([Hr, -Hi], axis=2),
             np.concatenate([Hi, Hr], axis=2)],
            axis=1,
        )

        # trace(H^T H) equals the sum of squared entries of H, so the full
        # batched Gram matrix does not have to be formed.
        channel_power = np.sum(np.square(H_real), axis=(1, 2))
        factor = channel_power / (2 * self.Nr)
        sigma2div2 = np.expand_dims(factor, -1) / snr
        # The noise is always drawn, even when noisefactor is falsy, so the
        # number of random draws per call is the same in both cases.
        noise = np.sqrt(sigma2div2) * np.random.randn(self.batch_size,
                                                      2 * self.Nr)
        sigma2 = 2 * sigma2div2

        y_real = self.batch_matvec_mul(H_real, x_real)
        if noisefactor:
            y_real = y_real + noise

        return x_real, H_real, y_real, sigma2

    def batch_matvec_mul(self, A, b):
        """Multiply each matrix in a batch by the matching vector.

        Args:
            A: Array of shape ``(batch_size, Nr, Nt)``.
            b: Array of shape ``(batch_size, Nt)``.

        Returns:
            Array ``C`` of shape ``(batch_size, Nr)`` where
            ``C[i] = A[i] @ b[i]``.
        """
        # Promote b to a batch of column vectors, multiply, then drop the
        # trailing singleton axis again.
        C = np.matmul(A, np.expand_dims(b, axis=2))

        return np.squeeze(C, -1)